package com.neuralnetwork.sample.util;
import com.neuralnetwork.sample.constant.Constant;
import com.neuralnetwork.sample.model.ImageModel;
import com.neuralnetwork.sample.ui.Canvas;

import javax.imageio.ImageIO;
import javax.swing.BorderFactory;

import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/*
Copyright [2017] [Pi Jing]

        Licensed under the Apache License, Version 2.0 (the "License");
        you may not use this file except in compliance with the License.
        You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

        Unless required by applicable law or agreed to in writing, software
        distributed under the License is distributed on an "AS IS" BASIS,
        WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        See the License for the specific language governing permissions and
        limitations under the License.
*/
/**
 * Image helpers for the digit-recognition sample: listing training images, rendering the drawing
 * canvas into an image, locating and drawing outlines of what was drawn, and converting images into
 * flat gray-value vectors.
 *
 * <p>Obtain the shared instance with {@link #getInstance()}. Its lazy initialization is not
 * synchronized, so concurrent first calls could create two instances.
 */
public class ImageUtil {
    private static ImageUtil imageUtil = null;

    /**
     * Factor used by {@link #getOutline} to map cells of the downscaled canvas raster (produced by
     * {@link #getRawGrayMatrixFromPanel}) back to canvas pixels.
     * NOTE(review): assumes the canvas is 10x the size of smallCanvasWidth x smallHeight — confirm.
     */
    private static final int CANVAS_SCALE = 10;

    /** Gray values at or below this threshold are treated as ink (1) by {@link #transGrayToBinaryValue}. */
    private static final double INK_THRESHOLD = 0.7;

    private final int smallWidth = Constant.smallWidth;
    private final int smallCanvasWidth = Constant.smallCanvasWidth;
    private final int smallHeight = Constant.smallHeight;

    private ImageUtil(){}

    /**
     * Returns the shared instance, creating it on first use.
     *
     * @return the singleton {@code ImageUtil}
     */
    public static ImageUtil getInstance(){
        if(imageUtil == null){
            imageUtil = new ImageUtil();
        }
        return imageUtil;
    }

    /**
     * Lists the absolute paths of all {@code .jpg} files directly inside {@code Constant.trainFolder}.
     *
     * @return matching paths; empty if the folder is missing, not a directory, or cannot be listed
     */
    public List<String> getImageList(){
        return listJpgFiles(new File(Constant.trainFolder));
    }

    /**
     * Lists the absolute paths of all {@code .jpg} files directly inside {@code Constant.mnistTrainFolder}.
     *
     * @return matching paths; empty if the folder is missing, not a directory, or cannot be listed
     */
    public List<String> getMnistImageList(){
        return listJpgFiles(new File(Constant.mnistTrainFolder));
    }

    /** Collects absolute paths of regular files ending in ".jpg" (case-sensitive) directly inside {@code dir}. */
    private static List<String> listJpgFiles(File dir){
        List<String> fileList = new ArrayList<String>();
        if(!dir.isDirectory()){
            return fileList;
        }
        File[] files = dir.listFiles();
        if(files == null){
            // listFiles() returns null when the directory cannot be read (I/O error).
            return fileList;
        }
        for(File fileItem : files){
            String path = fileItem.getAbsolutePath();
            if(fileItem.isFile() && path.endsWith(".jpg")){
                fileList.add(path);
            }
        }
        return fileList;
    }

    /**
     * Builds one {@link ImageModel} per image. Each image is scaled to smallWidth x smallHeight and
     * converted to a row-major gray vector. It is then paired with the digit encoded in its file name
     * (the text before the first {@code '_'}).
     *
     * <p>Files that fail with an I/O error, or that no ImageIO reader understands, are reported and
     * skipped. A file-name prefix that is not an integer makes {@link Integer#parseInt} throw
     * {@link NumberFormatException}.
     *
     * @param imageList image paths, e.g. from {@link #getImageList()}
     * @return models for the successfully loaded images, in input order
     */
    public List<ImageModel> getImageModel(List<String> imageList){
        List<ImageModel> list = new ArrayList<ImageModel>();
        for(String item : imageList){
            File imageFile = new File(item);
            try {
                BufferedImage bimage = ImageIO.read(imageFile);
                if(bimage == null){
                    // ImageIO.read returns null when no registered reader can decode the file.
                    System.err.println("Skipping unreadable image: " + item);
                    continue;
                }
                double[] grayMatrix = toGrayMatrix(scaleTo(bimage, smallWidth, smallHeight));
                Integer digit = Integer.parseInt(imageFile.getName().split("_")[0]);
                list.add(new ImageModel(grayMatrix, digit));
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        return list;
    }

    /**
     * Renders the canvas and crops it to the bounding box found by ImageSplitUtil's border detection.
     *
     * @param canvas the drawing canvas
     * @return the cropped sub-image (shares pixel data with the rendered image)
     * @throws Exception "Wrong image." if a border is not found (negative index) or the box is empty
     */
    public BufferedImage getInputImageFromPanel(com.neuralnetwork.sample.ui.Canvas canvas) throws Exception{
        BufferedImage image = renderCanvas(canvas);
        List<Boolean> horizonBorders = ImageSplitUtil.findHorizonBorders(image);
        List<Boolean> verticalBorders = ImageSplitUtil.findVerticalBorders(image);
        int[] outline = findOutline(horizonBorders, verticalBorders);
        int width = outline[2] - outline[0];
        int height = outline[3] - outline[1];
        if (width <= 0 || height <= 0) throw new Exception("Wrong image.");
        return image.getSubimage(outline[0], outline[1], width, height);
    }

    /**
     * Draws the overall bounding box of the drawing in red. Then, for each pair of split borders
     * returned by {@code ImageSplitUtil.returnSplitedBorders}, it draws an inset blue box on the canvas.
     *
     * @param canvas the drawing canvas to render from and draw on
     * @throws Exception "Wrong image." if a border is not found; "Wrong split!" if the split borders
     *     do not come in pairs
     */
    public void drawOutLinesFromPanel(com.neuralnetwork.sample.ui.Canvas canvas) throws Exception{
        BufferedImage image = renderCanvas(canvas);
        List<Boolean> horizonBorders = ImageSplitUtil.findHorizonBorders(image);
        List<Boolean> verticalBorders = ImageSplitUtil.findVerticalBorders(image);
        int[] outline = findOutline(horizonBorders, verticalBorders);
        drawRectangle(canvas, outline, Color.RED);

        List<Integer> vBorders = ImageSplitUtil.returnSplitedBorders(horizonBorders);
        System.out.println("JerryDebug:vborders:"+vBorders.size());
        if (vBorders.size()%2 != 0) throw new Exception("Wrong split!");
        for (int i = 0; i < vBorders.size(); i += 2) {
            // 2px inset so the blue boxes stay inside the red outline.
            drawRectangle(canvas, new int[] {vBorders.get(i)+2, outline[1]+2, vBorders.get(i+1)-2, outline[3]-2}, Color.BLUE);
        }
    }

    /**
     * Resolves the bounding box {left, top, right, bottom} from the border flag lists.
     * Left/right come from the horizontal list, top/bottom from the vertical list.
     */
    private static int[] findOutline(List<Boolean> horizonBorders, List<Boolean> verticalBorders) throws Exception{
        int left = ImageSplitUtil.findStartBorder(horizonBorders);
        int right = ImageSplitUtil.findEndBorder(horizonBorders);
        System.out.println(horizonBorders.contains(false)+":"+verticalBorders.contains(false));

        int top = ImageSplitUtil.findStartBorder(verticalBorders);
        int bottom = ImageSplitUtil.findEndBorder(verticalBorders);
        System.out.println("JerryDebug:outlines:"+left+":"+top+":"+right+":"+bottom);
        if (left<0||top<0||right<0||bottom<0) throw new Exception("Wrong image.");
        return new int[] {left, top, right, bottom};
    }

    /**
     * Draws a 2px-wide rectangle directly on the canvas's current graphics context. The drawing is
     * not persistent and is lost on the next repaint. Nothing is drawn if the canvas has no graphics
     * context (e.g. it is not displayable).
     *
     * @param canvas the canvas to draw on
     * @param rectangle {left, top, right, bottom} in canvas pixels
     * @param color stroke color
     * @throws Exception "Wrong retangle!" if {@code rectangle} does not have exactly 4 entries
     */
    public void drawRectangle(com.neuralnetwork.sample.ui.Canvas canvas,int[] rectangle,Color color) throws Exception{
        if (rectangle.length != 4) throw new Exception("Wrong retangle!");
        Graphics graphics = canvas.getGraphics();
        if (graphics == null) {
            // Component.getGraphics() returns null when the component has no graphics context.
            return;
        }
        Graphics2D g2d = (Graphics2D) graphics;
        try {
            g2d.setColor(color);
            g2d.setStroke(new BasicStroke(2));
            g2d.drawRect(rectangle[0], rectangle[1], rectangle[2]-rectangle[0], rectangle[3]-rectangle[1]);
        } finally {
            g2d.dispose();
        }
    }

    /**
     * Locates the ink on the downscaled canvas raster. It then records a square outline on the canvas
     * via {@code setOutLine}: the square's top-left is the ink's top-left, its side is the longer side
     * of the ink's bounding box, and it is scaled back to canvas pixels by {@link #CANVAS_SCALE}. If
     * nothing is drawn, the whole canvas is used instead.
     *
     * @param canvas the drawing canvas
     * @return {x, y, width, height} of the outline in canvas pixels
     */
    public static int[] getOutline(com.neuralnetwork.sample.ui.Canvas canvas){
        ImageUtil util = ImageUtil.getInstance();
        double[] grayMatrix = util.getRawGrayMatrixFromPanel(canvas);
        int[] binaryArray = util.transGrayToBinaryValue(grayMatrix);
        // getRawGrayMatrixFromPanel returns a row-major raster smallCanvasWidth cells wide.
        int rowWidth = util.smallCanvasWidth;

        int minRow = Integer.MAX_VALUE;
        int maxRow = Integer.MIN_VALUE;
        int minCol = Integer.MAX_VALUE;
        int maxCol = Integer.MIN_VALUE;
        for(int i = 0; i < binaryArray.length; i++){
            if(binaryArray[i] != 1){
                continue;
            }
            int row = i / rowWidth;
            int col = i % rowWidth;
            minRow = Math.min(minRow, row);
            maxRow = Math.max(maxRow, row);
            minCol = Math.min(minCol, col);
            maxCol = Math.max(maxCol, col);
        }

        int[] outline;
        if(maxRow < 0){
            // No ink found: fall back to the whole canvas instead of overflowing the min/max sentinels.
            outline = new int[] {0, 0, canvas.getWidth(), canvas.getHeight()};
        }else{
            int len = Math.max(maxCol - minCol + 1, maxRow - minRow + 1) * CANVAS_SCALE;
            outline = new int[] {minCol * CANVAS_SCALE, minRow * CANVAS_SCALE, len, len};
        }
        canvas.setOutLine(outline[0], outline[1], outline[2], outline[3]);
        return outline;
    }

    /**
     * Renders the canvas, optionally crops it, scales it to smallWidth x smallHeight and returns the
     * gray values in row-major order.
     *
     * @param canvas the drawing canvas
     * @param outline optional {x, y, width, height} crop in canvas pixels (e.g. from {@link #getOutline}).
     *     It is clipped to the canvas and is not modified. {@code null} means the whole canvas.
     * @return gray values of length smallWidth * smallHeight
     * @throws java.awt.image.RasterFormatException if the clipped crop region is empty
     */
    public double[] getGrayMatrixFromPanel(com.neuralnetwork.sample.ui.Canvas canvas, int[] outline){
        BufferedImage image = renderCanvas(canvas);
        if(outline != null){
            // Intersect the requested region with the canvas bounds without touching the caller's array.
            int x = Math.max(0, outline[0]);
            int y = Math.max(0, outline[1]);
            int width = Math.min(image.getWidth(), outline[0] + outline[2]) - x;
            int height = Math.min(image.getHeight(), outline[1] + outline[3]) - y;
            System.out.println("JerryDebug:"+x+":"+y+":"+width+":"+height);
            image = image.getSubimage(x, y, width, height);
        }
        return toGrayMatrix(scaleTo(image, smallWidth, smallHeight));
    }

    /**
     * Letterboxes {@code image} onto a white smallWidth x smallHeight tile, preserving its aspect
     * ratio, and returns the tile's gray values in row-major order.
     *
     * <p>Tall or square images are scaled to height smallHeight-2 and centered horizontally inside a
     * 1-pixel margin. Wide images are scaled to width smallWidth and centered vertically with no margin.
     * NOTE(review): the margin differs between the two branches; kept as-is to avoid changing
     * network input — confirm whether it is intended. Each call also passes the tile to
     * {@code ImageSplitUtil.saveImage}, presumably for debugging.
     *
     * @param image the glyph image to convert
     * @return gray values of length smallWidth * smallHeight
     * @throws Exception propagated from {@code ImageSplitUtil.saveImage}
     */
    public double[] getGrayMatrixFromBufferedImage(BufferedImage image) throws Exception{
        BufferedImage bSmallImage;
        if (image.getHeight() >= image.getWidth()) {
            // At least 1px wide: getScaledInstance rejects a zero dimension.
            int smallCharWidth = Math.max(1, (int)((smallWidth-2)*(image.getWidth()/(double)image.getHeight())));
            int startx = (int)((smallWidth-2-smallCharWidth)/2.0)+1;
            System.out.println("JerryDebug:"+smallCharWidth+":"+startx);
            bSmallImage = placeOnWhiteTile(image, smallCharWidth, smallHeight-2, startx, 1);
            ImageSplitUtil.saveImage(bSmallImage,"bSmallImage"+System.currentTimeMillis());
        } else {
            int smallCharHeight = Math.max(1, (int)(smallHeight*(image.getHeight()/(double)image.getWidth())));
            int starty = (int)((smallHeight-smallCharHeight)/2.0);
            System.out.println("JerryDebug:"+smallCharHeight+":"+starty);
            bSmallImage = placeOnWhiteTile(image, smallWidth, smallCharHeight, 0, starty);
            ImageSplitUtil.saveImage(bSmallImage,"bSmallImage");
        }
        return toGrayMatrix(bSmallImage);
    }

    /** Scales {@code source} to scaledWidth x scaledHeight and draws it at (x, y) on a new white smallWidth x smallHeight tile. */
    private BufferedImage placeOnWhiteTile(Image source, int scaledWidth, int scaledHeight, int x, int y){
        Image scaled = source.getScaledInstance(scaledWidth, scaledHeight, Image.SCALE_SMOOTH);
        BufferedImage tile = new BufferedImage(smallWidth, smallHeight, BufferedImage.TYPE_INT_RGB);
        Graphics graphics = tile.getGraphics();
        try {
            // Fill the whole tile with white first so the letterbox borders are background.
            graphics.setColor(Color.WHITE);
            graphics.fillRect(0, 0, smallWidth, smallHeight);
            graphics.drawImage(scaled, x, y, null);
        } finally {
            graphics.dispose();
        }
        return tile;
    }

    /**
     * Renders the whole canvas scaled to smallCanvasWidth x smallHeight.
     *
     * @param canvas the drawing canvas
     * @return gray values in row-major order (index = row * smallCanvasWidth + col)
     */
    public double[] getRawGrayMatrixFromPanel(com.neuralnetwork.sample.ui.Canvas canvas){
        return toGrayMatrix(scaleTo(renderCanvas(canvas), smallCanvasWidth, smallHeight));
    }

    /**
     * Thresholds gray values into ink (1) / background (0).
     *
     * @param input gray values
     * @return 1 where the value is at or below {@link #INK_THRESHOLD}, else 0 (NaN maps to 0)
     */
    public int[] transGrayToBinaryValue(double[] input){
        int[] binaryArray = new int[input.length];
        for(int i = 0; i < input.length; i++){
            binaryArray[i] = Double.compare(INK_THRESHOLD, input[i]) >= 0 ? 1 : 0;
        }
        return binaryArray;
    }

    /**
     * Renders the canvas at full size and returns the mean gray value of each pixel column.
     *
     * @param canvas the drawing canvas
     * @return one mean per column, length = canvas width
     */
    public double[] getRasterMatrixFromCanvas(Canvas canvas) {
        BufferedImage image = renderCanvas(canvas);
        int width = image.getWidth();
        int height = image.getHeight();
        double[] gray = toGrayMatrix(image);

        double[] columnMeans = new double[width];
        for(int row = 0; row < height; row++){
            int offset = row * width;
            for(int col = 0; col < width; col++){
                columnMeans[col] += gray[offset + col];
            }
        }
        for(int col = 0; col < width; col++){
            columnMeans[col] = columnMeans[col] / height;
        }
        return columnMeans;
    }

    /** Paints the canvas into a new TYPE_INT_RGB image of the canvas's current size. */
    private static BufferedImage renderCanvas(com.neuralnetwork.sample.ui.Canvas canvas){
        Dimension imageSize = canvas.getSize();
        BufferedImage image = new BufferedImage(imageSize.width, imageSize.height, BufferedImage.TYPE_INT_RGB);
        Graphics2D graphics = image.createGraphics();
        try {
            canvas.paint(graphics);
        } finally {
            graphics.dispose();
        }
        return image;
    }

    /** Smoothly scales {@code source} to exactly width x height into a new TYPE_INT_RGB image. */
    private static BufferedImage scaleTo(Image source, int width, int height){
        Image scaled = source.getScaledInstance(width, height, Image.SCALE_SMOOTH);
        BufferedImage result = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        Graphics graphics = result.getGraphics();
        try {
            graphics.drawImage(scaled, 0, 0, null);
        } finally {
            graphics.dispose();
        }
        return result;
    }

    /**
     * Converts every pixel of a TYPE_INT_RGB image to a gray value in raster (row-major) order.
     * Formula: (r*299 + g*587 + b*114 + 500) / 255000, so black maps to ~0.002 and white to ~1.002.
     */
    private static double[] toGrayMatrix(BufferedImage image){
        // TYPE_INT_RGB rasters hand back one packed int per pixel.
        int[] pixels = (int[]) image.getRaster().getDataElements(0, 0, image.getWidth(), image.getHeight(), null);
        double[] gray = new double[pixels.length];
        for(int k = 0; k < pixels.length; k++){
            int rgb = pixels[k];
            int r = (rgb & 0xff0000) >> 16;
            int g = (rgb & 0xff00) >> 8;
            int b = (rgb & 0xff);
            gray[k] = (r * 299 + g * 587 + b * 114 + 500) / 255000.0;
        }
        return gray;
    }
}
